package Q4_12_Paths_with_Sum;
import java.util.HashMap;
import CtCILibrary.TreeNode;
public class QuestionB {

    /**
     * Counts the downward paths in a tree whose node values add up to {@code targetSum}.
     * A path may start and end at any node, as long as it only moves from parent to child.
     *
     * @param root      root of the tree; {@code null} yields 0
     * @param targetSum the sum a path must have to be counted
     * @return the number of matching downward paths
     */
    public static int countPathsWithSum(TreeNode root, int targetSum) {
        HashMap<Integer, Integer> prefixSumCounts = new HashMap<>();
        return countPathsWithSum(root, targetSum, 0, prefixSumCounts);
    }

    /**
     * Recursive worker: counts matching paths that end at {@code node} or anywhere below it.
     *
     * <p>{@code runningSum} is the sum of the values on the path above {@code node}.
     * {@code pathCount} maps each running sum seen on that path to how often it occurred.
     * A path ending at the current node sums to {@code targetSum} exactly when some earlier
     * running sum equals {@code current - targetSum}. The empty prefix (sum 0) is never stored
     * in the map, so paths that begin at the top of the recursion are counted by a separate
     * equality check. Each node is visited once. On return, {@code pathCount} holds the same
     * entries it had on entry.
     *
     * <p>NOTE(review): sums use {@code int} arithmetic and wrap silently on overflow.
     *
     * @param node       current node; {@code null} contributes 0
     * @param targetSum  the sum a path must have to be counted
     * @param runningSum sum of the values from the starting node down to {@code node}'s parent
     * @param pathCount  running-sum frequencies for the ancestors of {@code node}; mutated
     *                   during the call and restored before returning
     * @return the number of matching paths that end in the subtree rooted at {@code node}
     */
    public static int countPathsWithSum(TreeNode node, int targetSum, int runningSum, HashMap<Integer, Integer> pathCount) {
        if (node == null) {
            return 0;
        }

        final int sumHere = runningSum + node.data;

        // Look up the earlier prefixes before recording sumHere itself, so an empty path is never counted.
        int found = pathCount.getOrDefault(sumHere - targetSum, 0);
        if (sumHere == targetSum) {
            found += 1; // the path that starts at the very top of the recursion
        }

        incrementHashTable(pathCount, sumHere, 1);
        found += countPathsWithSum(node.left, targetSum, sumHere, pathCount)
               + countPathsWithSum(node.right, targetSum, sumHere, pathCount);
        incrementHashTable(pathCount, sumHere, -1); // undo before returning to the parent

        return found;
    }

    /**
     * Adds {@code delta} to the count stored for {@code key}. A missing key counts as 0.
     * The entry is removed if the new count is 0, so the map only holds non-zero counts.
     *
     * @param hashTable the frequency map to update in place
     * @param key       the key whose count changes
     * @param delta     amount to add (may be negative)
     */
    public static void incrementHashTable(HashMap<Integer, Integer> hashTable, int key, int delta) {
        final int updated = hashTable.getOrDefault(key, 0) + delta;
        if (updated != 0) {
            hashTable.put(key, updated);
            return;
        }
        hashTable.remove(key); // keep the map small: zero counts carry no information
    }

    /**
     * Small demo: builds a tree where every node holds 0 and prints the path counts
     * for targets 0 and 4.
     */
    public static void main(String [] args) {
        //        0
        //       / \
        //      0   0
        //         / \
        //        0   0
        //         \
        //          0
        TreeNode root = new TreeNode(0);
        root.left = new TreeNode(0);
        TreeNode right = new TreeNode(0);
        root.right = right;
        TreeNode rightLeft = new TreeNode(0);
        right.left = rightLeft;
        rightLeft.right = new TreeNode(0);
        right.right = new TreeNode(0);

        int[] targets = {0, 4};
        for (int target : targets) {
            System.out.println(countPathsWithSum(root, target));
        }
    }
}